from typing import Optional

import wandb

from centralized_verification.configuration import Configuration, TestConfiguration
from centralized_verification.paths import CHECKPOINT_DIR
from centralized_verification.rolling_stat_tracker import RollingStatTracker
from centralized_verification.training_state import TrainingState
from centralized_verification.utils import prefix_dict, TrainingProgress


def _save_checkpoint(config: Configuration, training_progress: TrainingProgress):
    """Save the global step/episode counters and the learner's state dict for this run.

    The target path is ``{CHECKPOINT_DIR}/{config.run_name}.pt``. The same path is used on
    every call, so presumably each save replaces the previous one (depends on
    ``TrainingState.save`` — confirm).
    """
    state = TrainingState(global_step_count=training_progress.global_step_count,
                          global_episode_count=training_progress.global_episode_count,
                          learner_state_dict=config.learner.state_dict())
    state.save(f"{CHECKPOINT_DIR}/{config.run_name}.pt")


def _build_train_log_dict(config: Configuration, training_progress: TrainingProgress, ep_reward_trackers,
                          num_unsafe_actions_tracker: RollingStatTracker,
                          episode_length_tracker: RollingStatTracker) -> dict:
    """Assemble the wandb payload for one training log entry.

    Contains the rolling unsafe-action average and total, the global episode/step counters,
    the rolling episode length, the learner's own log entries (prefixed ``train/learner/``),
    and one rolling episode-reward average per agent.
    """
    log_dict = {"train/num_unsafe_actions": num_unsafe_actions_tracker.average(),
                "train/sum_unsafe_actions": num_unsafe_actions_tracker.total_sum,
                "train/episode_count": training_progress.global_episode_count,
                "train/step_count": training_progress.global_step_count,
                "train/episode_length": episode_length_tracker.average()}

    learner_log_dict = config.learner.get_log_dict()
    log_dict.update(prefix_dict(learner_log_dict, "train/learner/"))

    for i, ert in enumerate(ep_reward_trackers):
        log_dict[f"train/episode_rewards_{i}"] = ert.average()

    return log_dict


def train_loop(config: Configuration, training_state: Optional[TrainingState] = None):
    """Run shielded training episodes until ``config.limits.should_stop_training`` says to stop.

    Each step works as follows:
    1. The learner proposes a joint action.
    2. The shield evaluates it and may replace it.
    3. The environment is stepped with the shield's action.
    4. The transition (including the shield result) is passed to the learner.

    Rolling statistics are kept over recent episodes: per-agent episode reward, unsafe
    actions per episode, and episode length. They are logged to wandb and checkpoints are
    saved whenever ``config.limits.is_at_logging_interval`` fires for ``num_log_entries`` /
    ``num_checkpoints`` respectively. A checkpoint is additionally saved after every episode.

    Args:
        config: Training configuration providing env, shield, learner, limits, logging and
            checkpoint counts, and the run name.
        training_state: Optional state from a previous run. Only its global step and
            episode counters are used here to resume progress. Restoring the learner's
            weights is presumably done by the caller — confirm.
    """
    if training_state:
        training_progress = TrainingProgress(global_step_count=training_state.global_step_count,
                                             global_episode_count=training_state.global_episode_count)
    else:
        training_progress = TrainingProgress()

    num_agents = config.learner.num_agents()

    # Presumably RollingStatTracker(window_size=10, initial_value=0) — confirm argument meaning.
    ep_reward_trackers = [RollingStatTracker(10, 0) for _ in range(num_agents)]
    num_unsafe_actions_tracker = RollingStatTracker(10, 0)
    episode_length_tracker = RollingStatTracker(10, 0)

    while not config.limits.should_stop_training(training_progress):
        current_state, joint_obs = config.env.initial_state()
        shield_state = config.shield.get_initial_shield_state(current_state, joint_obs)
        end_of_episode = False
        sum_rew = [0] * num_agents
        num_unsafe_actions = 0
        step_num = 0
        training_progress.global_episode_count += 1

        while not end_of_episode:
            training_progress.global_step_count += 1
            step_num += 1
            joint_action = config.learner.get_joint_action(joint_obs, training_progress)

            # Shields may choose to ignore current_state to simulate a lack of communication
            shield_result, shield_state = config.shield.evaluate_joint_action(current_state, joint_obs, joint_action,
                                                                              shield_state)

            # The environment executes the shield's (possibly replaced) actions, not the learner's proposal.
            shield_replaced_joint_action = [agent_result.real_action.action for agent_result in shield_result]

            next_state, next_joint_obs, rewards, done, is_safe = config.env.step(current_state,
                                                                                 shield_replaced_joint_action)
            sum_rew = [r1 + r2 for r1, r2 in zip(sum_rew, rewards)]

            config.learner.observe_transition(joint_obs, shield_result, next_joint_obs, rewards, done,
                                              step_num, training_progress)
            current_state = next_state
            joint_obs = next_joint_obs
            if not is_safe:
                num_unsafe_actions += 1

            max_episode_len = config.limits.max_episode_len
            # None means "no length limit"; >= guards against a non-positive limit never matching.
            end_of_episode = done or (max_episode_len is not None and step_num >= max_episode_len)

            if end_of_episode:
                # Record the finished episode before the logging check below, so that a log
                # emitted at the end of this episode already includes its statistics.
                for ert, sr in zip(ep_reward_trackers, sum_rew):
                    ert.track(sr)
                num_unsafe_actions_tracker.track(num_unsafe_actions)
                episode_length_tracker.track(step_num)

            if config.limits.is_at_logging_interval(training_progress, config.num_log_entries, end_of_episode):
                wandb.log(_build_train_log_dict(config, training_progress, ep_reward_trackers,
                                                num_unsafe_actions_tracker, episode_length_tracker))

            if config.limits.is_at_logging_interval(training_progress, config.num_checkpoints, end_of_episode):
                _save_checkpoint(config, training_progress)

        # NOTE(review): this unconditional per-episode save makes the interval-based save at
        # episode end redundant; kept to preserve existing resumability behavior — confirm intent.
        _save_checkpoint(config, training_progress)


def test_loop(config: TestConfiguration):
    """Evaluate ``config.agent`` behind ``config.shield`` for ``config.limits.num_episodes`` episodes.

    After each episode, one wandb entry is logged with:
    - the number of unsafe actions,
    - the 0-based episode index,
    - the episode length,
    - the per-agent reward sums.

    After all episodes, one final entry is logged with the per-agent average reward sum and
    the average number of unsafe actions per episode. No learning happens here. The agent
    is only asked for actions, with ``None`` in place of a training-progress object.

    An episode ends when the environment reports ``done`` or after
    ``config.limits.max_episode_len`` steps. A limit of ``None`` means no limit, matching
    ``train_loop``. If no episodes are run, the final averages are not logged, since they
    would be undefined.
    """
    num_agents = config.agent.num_agents()
    num_episodes = config.limits.num_episodes
    max_episode_len = config.limits.max_episode_len

    global_sum_rew = [0] * num_agents
    global_sum_unsafe_actions = 0

    for test_ep_num in range(num_episodes):
        current_state, joint_obs = config.env.initial_state()
        shield_state = config.shield.get_initial_shield_state(current_state, joint_obs)
        done = False
        sum_rew = [0] * num_agents
        num_unsafe_actions = 0
        step_num = 0

        while not done and (max_episode_len is None or step_num < max_episode_len):
            step_num += 1
            joint_action = config.agent.get_joint_action(joint_obs, None)
            shield_result, shield_state = config.shield.evaluate_joint_action(current_state, joint_obs, joint_action,
                                                                              shield_state)

            # The environment executes the shield's (possibly replaced) actions, not the agent's proposal.
            shield_replaced_joint_action = [agent_result.real_action.action for agent_result in shield_result]

            next_state, next_joint_obs, rewards, done, is_safe = config.env.step(current_state,
                                                                                 shield_replaced_joint_action)
            sum_rew = [r1 + r2 for r1, r2 in zip(sum_rew, rewards)]

            current_state = next_state
            joint_obs = next_joint_obs
            if not is_safe:
                num_unsafe_actions += 1

        log_dict = {"test/num_unsafe_actions": num_unsafe_actions,
                    "test/episode_count": test_ep_num,
                    "test/episode_length": step_num}

        for i, sr in enumerate(sum_rew):
            log_dict[f"test/episode_rewards_{i}"] = sr

        wandb.log(log_dict)

        global_sum_rew = [r1 + r2 for r1, r2 in zip(sum_rew, global_sum_rew)]
        global_sum_unsafe_actions += num_unsafe_actions

    if num_episodes <= 0:
        # Nothing was evaluated; averaging would divide by zero (or by a negative count).
        return

    log_dict = {f"test/avg_rew_sum_{i}": float(sr) / num_episodes for i, sr in enumerate(global_sum_rew)}
    log_dict["test/avg_unsafe_actions"] = float(global_sum_unsafe_actions) / num_episodes

    wandb.log(log_dict)
